Add batched RoPE kernel #3095
Conversation
@@ -77,6 +77,48 @@ __global__ void rotary_embedding_kernel(
  }
}

template<typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
This kernel is almost exactly the same as `rotary_embedding_kernel`, and you can make them the same by adding the `const int64_t* __restrict__ cos_sin_cache_offsets` argument there (it will be a null pointer if it is not set) and then, down below, doing `int64_t cos_sin_cache_offset = cos_sin_cache_offsets ? cos_sin_cache_offsets[token_idx] : 0;`.
`cos_sin_cache_offset` is passed as a pointer, so we don't have a good way to determine whether it's empty without an auxiliary flag, and we also try to avoid runtime branching in kernel code for performance. Agreed that these two kernels are pretty much the same, so I refactored them to avoid too much code duplication.
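For readers following the exchange, here is a pure-PyTorch sketch of the lookup semantics being discussed: per-token `cos_sin_cache_offsets` shifting the index into `cos_sin_cache`, with an offset of 0 when no offsets are supplied. The tensor names and the assumption that the offset is added to the position are taken from the suggested line above, not from the final kernel code.

```python
import torch
from typing import Optional

def rope_cache_lookup(
    positions: torch.Tensor,                                # [num_tokens]
    cos_sin_cache: torch.Tensor,                            # [cache_len, rotary_dim]
    cos_sin_cache_offsets: Optional[torch.Tensor] = None,   # [num_tokens] or None
) -> torch.Tensor:
    # Mirrors the suggested kernel logic:
    #   offset = cos_sin_cache_offsets ? cos_sin_cache_offsets[token_idx] : 0
    if cos_sin_cache_offsets is None:
        index = positions
    else:
        index = positions + cos_sin_cache_offsets
    return cos_sin_cache[index]                              # [num_tokens, rotary_dim]
```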
Do you have a (micro-)benchmark that shows the difference between batched and non-batched to justify the change?
@tterrysun Can you add the micro-benchmarks so we can measure the performance here? You can put them into `benchmarks/kernels/`.
Benchmarking command:
Note that this simulates serving 4 LoRAs; the more LoRAs served, the bigger the difference between the single batched kernel call and multiple non-batched kernel calls, and the majority of the difference should come from the Python side. When serving a single LoRA, they should be equivalent.
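The command itself isn't reproduced above, but a minimal sketch of the comparison being described could look like the following. `rope_lookup` is a stand-in for the real kernels, and the cache layout (one block per LoRA) is an assumption; the point is that the per-LoRA Python loop adds dispatch overhead that grows with the number of LoRAs served.

```python
import time
import torch

# Stand-in for the real ops: a table lookup applied either once over the
# whole batch, or once per LoRA group.
def rope_lookup(positions, cache, offset=0):
    return cache[positions + offset]

num_loras, tokens_per_lora, max_pos, rotary_dim = 4, 256, 4096, 128
device = "cuda" if torch.cuda.is_available() else "cpu"
# Assume the cache concatenates one [max_pos, rotary_dim] block per LoRA.
cache = torch.randn(num_loras * max_pos, rotary_dim, device=device)
positions = torch.randint(0, max_pos, (num_loras * tokens_per_lora,), device=device)
offsets = torch.repeat_interleave(
    torch.arange(num_loras, device=device) * max_pos, tokens_per_lora)

def bench(fn, iters=100):
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

batched = bench(lambda: rope_lookup(positions, cache, offsets))
looped = bench(lambda: [
    rope_lookup(positions[i * tokens_per_lora:(i + 1) * tokens_per_lora],
                cache, i * max_pos)
    for i in range(num_loras)
])
print(f"batched: {batched * 1e6:.1f} us, per-LoRA loop: {looped * 1e6:.1f} us")
```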
benchmarks/kernels/benchmark_rope.py
type=int,
choices=[64, 80, 96, 112, 128, 256],
default=128)
parser.add_argument("--rottery-dim",
parser.add_argument("--rottery-dim", | |
parser.add_argument("--rotary-dim", |
benchmarks/kernels/benchmark_rope.py
seq_len=args.seq_len,
num_heads=args.num_heads,
head_size=args.head_size,
rotary_dim=args.rottery_dim,
rotary_dim=args.rottery_dim,
rotary_dim=args.rotary_dim,
@@ -158,27 +169,30 @@ def __init__(
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
scaling_factors: List[float],
Can we also take in a `float` by itself (and coerce it into a list inside the init)?
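A minimal sketch of the coercion being requested; the class name is illustrative and only the relevant fragment of the init is shown:

```python
from typing import List, Union

class BatchedRotaryEmbedding:  # illustrative name, not the actual class
    def __init__(self, scaling_factors: Union[List[float], float]) -> None:
        # Accept a bare float for the single-scaling-factor case and coerce
        # it into a list so the rest of the code only deals with lists.
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: List[float] = scaling_factors
```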
@@ -107,7 +108,9 @@ def _forward(
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

cos_sin = self.cos_sin_cache[positions]
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
Should use `positions.device` rather than `positions.get_device()`.
https://pytorch.org/docs/stable/generated/torch.Tensor.get_device.html
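A small, self-contained sketch of the suggested fix; per the linked docs, `get_device()` returns an integer ordinal (-1 for CPU tensors), while `.device` is a `torch.device` that works for both CPU and CUDA tensors:

```python
import torch

positions = torch.arange(8)        # may live on CPU or CUDA
cos_sin_cache = torch.randn(8, 4)

# positions.get_device() yields an int ordinal (-1 on CPU), which is not a
# reliable argument for .to(); positions.device is valid in both cases.
cos_sin_cache = cos_sin_cache.to(positions.device)
cos_sin = cos_sin_cache[positions]
```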
Hi, I am wondering if this kernel is currently used? I don't see changes in the model code, and I'm not sure where the follow-up PRs are.
Problem: Currently we need to call the rotary embedding kernel once per LoRA request, which makes it very inefficient to serve multiple LoRAs with different context lengths.
Solution: Add a batched rotary embedding kernel. Follow-up PRs will pipe it through.
Testing: Batched kernel tests. Follow-up PRs will add e2e tests.
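As a rough, hypothetical illustration of how the batched kernel could be driven (the cache layout and tensor names are assumptions, not this PR's final API): if the cos/sin cache stores one block of `max_position_embeddings` entries per scaling factor, a batch that mixes requests from different LoRAs can be handled with a single kernel call by passing per-token offsets into that cache, instead of one kernel launch per LoRA group.

```python
import torch

max_position_embeddings = 4096
scaling_factors = [1.0, 2.0, 4.0, 8.0]     # e.g. one per served LoRA

# Hypothetical mapping from each token in the batch to the index of the
# scaling factor its LoRA uses.
token_scaling_idx = torch.tensor([0, 0, 1, 3, 3, 3, 2])

# Per-token offsets into a cache laid out as one block per scaling factor.
cos_sin_cache_offsets = token_scaling_idx * max_position_embeddings

# A single batched kernel call would then take `positions` plus these
# offsets, rather than looping over LoRA groups in Python.
```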